{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-09T08:59:53.479916800Z",
     "start_time": "2023-05-09T08:59:53.087915800Z"
    }
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'common'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[3], line 2\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mcommon\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfunctions\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m sigmoid\n",
      "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'common'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from common.functions import sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "class LSTM:\n",
    "    \"\"\"Single-time-step LSTM cell with a hand-written backward pass.\n",
    "\n",
    "    The affine output is read as four equal-width column groups, in the order\n",
    "    forget gate, candidate (tanh), input gate, output gate.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, Wx, Wh, b):\n",
    "        \"\"\"Keep the parameters and allocate zero gradient buffers of matching shape.\n",
    "\n",
    "        Args:\n",
    "            Wx: Input weights, used as np.dot(x, Wx).\n",
    "            Wh: Recurrent weights, used as np.dot(h_prev, Wh).\n",
    "            b: Bias added to the affine output.\n",
    "        \"\"\"\n",
    "        self.params = [Wx, Wh, b]\n",
    "        self.grads = [np.zeros_like(param) for param in (Wx, Wh, b)]\n",
    "        self.cache = None\n",
    "\n",
    "    def forward(self, x, h_prev, c_prev):\n",
    "        \"\"\"Advance one time step and return (h_next, c_next).\n",
    "\n",
    "        The inputs and the gate activations are stored on self.cache so that\n",
    "        backward() can reuse them.\n",
    "        \"\"\"\n",
    "        Wx, Wh, b = self.params\n",
    "        _, hidden = h_prev.shape\n",
    "\n",
    "        affine = np.dot(x, Wx) + np.dot(h_prev, Wh) + b\n",
    "\n",
    "        # Column groups of width `hidden`: forget | candidate | input | output.\n",
    "        forget_gate = sigmoid(affine[:, :hidden])\n",
    "        candidate = np.tanh(affine[:, hidden:2 * hidden])\n",
    "        input_gate = sigmoid(affine[:, 2 * hidden:3 * hidden])\n",
    "        output_gate = sigmoid(affine[:, 3 * hidden:])\n",
    "\n",
    "        c_next = forget_gate * c_prev + candidate * input_gate\n",
    "        h_next = output_gate * np.tanh(c_next)\n",
    "\n",
    "        # Tuple order is (x, h_prev, c_prev, i, f, g, o, c_next); backward() unpacks it the same way.\n",
    "        self.cache = (x, h_prev, c_prev, input_gate, forget_gate, candidate, output_gate, c_next)\n",
    "        return h_next, c_next\n",
    "\n",
    "    def backward(self, dh_next, dc_next):\n",
    "        \"\"\"Back-propagate through the step recorded by the last forward() call.\n",
    "\n",
    "        Overwrites the arrays in self.grads in place with the gradients for\n",
    "        Wx, Wh and b, and returns (dx, dh_prev, dc_prev).\n",
    "        \"\"\"\n",
    "        Wx, Wh, _ = self.params\n",
    "        x, h_prev, c_prev, input_gate, forget_gate, candidate, output_gate, c_next = self.cache\n",
    "\n",
    "        tanh_c = np.tanh(c_next)\n",
    "\n",
    "        # Gradient reaching the cell state: the direct path plus the path through h_next.\n",
    "        d_cell = dc_next + (dh_next * output_gate) * (1 - tanh_c ** 2)\n",
    "\n",
    "        # Per-gate gradient, then scaled by the derivative of that gate's activation.\n",
    "        d_forget = d_cell * c_prev\n",
    "        d_forget *= forget_gate * (1 - forget_gate)\n",
    "        d_candidate = d_cell * input_gate\n",
    "        d_candidate *= (1 - candidate ** 2)\n",
    "        d_input = d_cell * candidate\n",
    "        d_input *= input_gate * (1 - input_gate)\n",
    "        d_output = dh_next * tanh_c\n",
    "        d_output *= output_gate * (1 - output_gate)\n",
    "\n",
    "        # Same column order as the slicing in forward().\n",
    "        d_affine = np.hstack((d_forget, d_candidate, d_input, d_output))\n",
    "\n",
    "        new_grads = (\n",
    "            np.dot(x.T, d_affine),\n",
    "            np.dot(h_prev.T, d_affine),\n",
    "            d_affine.sum(axis=0),\n",
    "        )\n",
    "        # Write into the existing buffers so outside references to them stay valid.\n",
    "        for buffer, value in zip(self.grads, new_grads):\n",
    "            buffer[...] = value\n",
    "\n",
    "        dx = np.dot(d_affine, Wx.T)\n",
    "        dh_prev = np.dot(d_affine, Wh.T)\n",
    "        dc_prev = d_cell * forget_gate\n",
    "\n",
    "        return dx, dh_prev, dc_prev"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T09:06:26.829512900Z",
     "start_time": "2023-05-09T09:06:26.808447200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "class TimeLSTM:\n",
    "    \"\"\"Applies an LSTM cell to every time step of a batch of sequences.\n",
    "\n",
    "    Args:\n",
    "        Wx, Wh, b: Parameters shared by all time steps; they are passed\n",
    "            unchanged to each per-step LSTM cell.\n",
    "        stateful: If True, the hidden and cell state left by one forward()\n",
    "            call are the starting state of the next call. If False, both\n",
    "            restart from zeros on every forward() call.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, Wx, Wh, b, stateful=False):\n",
    "        self.params = [Wx, Wh, b]\n",
    "        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]\n",
    "        self.layers = None  # per-step LSTM cells, rebuilt by forward()\n",
    "\n",
    "        self.h, self.c = None, None  # most recent hidden / cell state\n",
    "        self.dh = None  # gradient flowing out of the earliest step, set by backward()\n",
    "        self.stateful = stateful\n",
    "\n",
    "    def forward(self, xs):\n",
    "        \"\"\"Run all time steps and return the hidden state of each one.\n",
    "\n",
    "        Args:\n",
    "            xs: Array of shape (N, T, D).\n",
    "\n",
    "        Returns:\n",
    "            float32 array of shape (N, T, H), where H is Wh.shape[0].\n",
    "        \"\"\"\n",
    "        Wh = self.params[1]\n",
    "        N, T, _ = xs.shape\n",
    "        H = Wh.shape[0]\n",
    "\n",
    "        self.layers = []\n",
    "        hs = np.empty((N, T, H), dtype='f')\n",
    "\n",
    "        # Start from zeros unless a stateful layer already carries state.\n",
    "        if not self.stateful or self.h is None:\n",
    "            self.h = np.zeros((N, H), dtype='f')\n",
    "        if not self.stateful or self.c is None:\n",
    "            self.c = np.zeros((N, H), dtype='f')\n",
    "\n",
    "        for t in range(T):\n",
    "            # One cell per step so that each keeps its own cache for backward().\n",
    "            layer = LSTM(*self.params)\n",
    "            self.h, self.c = layer.forward(xs[:, t, :], self.h, self.c)\n",
    "            hs[:, t, :] = self.h\n",
    "            self.layers.append(layer)\n",
    "\n",
    "        return hs\n",
    "\n",
    "    def backward(self, dhs):\n",
    "        \"\"\"Back-propagate through time over the steps recorded by forward().\n",
    "\n",
    "        Args:\n",
    "            dhs: Array of shape (N, T, H) with the upstream gradient of every\n",
    "                hidden state returned by forward().\n",
    "\n",
    "        Returns:\n",
    "            float32 array of shape (N, T, D) with the gradient of the inputs.\n",
    "\n",
    "        Side effects:\n",
    "            Writes the gradients summed over all steps into self.grads (all\n",
    "            three entries) and stores the gradient leaving the earliest step\n",
    "            in self.dh.\n",
    "        \"\"\"\n",
    "        Wx = self.params[0]\n",
    "        N, T, _ = dhs.shape\n",
    "        D = Wx.shape[0]\n",
    "\n",
    "        dxs = np.empty((N, T, D), dtype='f')\n",
    "        dh, dc = 0, 0\n",
    "\n",
    "        grads = [0, 0, 0]\n",
    "        for t in reversed(range(T)):\n",
    "            layer = self.layers[t]\n",
    "            # Each hidden state receives gradient from its own output and from step t + 1.\n",
    "            dx, dh, dc = layer.backward(dhs[:, t, :] + dh, dc)\n",
    "            dxs[:, t, :] = dx\n",
    "            for i, grad in enumerate(layer.grads):\n",
    "                grads[i] += grad\n",
    "\n",
    "        for i, grad in enumerate(grads):\n",
    "            self.grads[i][...] = grad\n",
    "\n",
    "        # These two statements must stay outside the loop above; returning from\n",
    "        # inside it would leave the Wh and b gradients unwritten.\n",
    "        self.dh = dh\n",
    "        return dxs\n",
    "\n",
    "    def set_state(self, h, c=None):\n",
    "        \"\"\"Replace the stored hidden state and cell state.\"\"\"\n",
    "        self.h, self.c = h, c\n",
    "\n",
    "    def reset_state(self):\n",
    "        \"\"\"Clear the stored state; the next forward() call starts from zeros.\"\"\"\n",
    "        self.h, self.c = None, None\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T09:16:39.725059400Z",
     "start_time": "2023-05-09T09:16:39.710042200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
